{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Otto商品分类——线性SVM\n",
    "\n",
    "我们以Kaggle 2015年举办的Otto Group Product Classification Challenge竞赛数据为例，分别演示\n",
    "缺省参数的LinearSVC、\n",
    "LinearSVC + 校验集（hold-out；交叉验证对SVM太慢）进行正则参数C的调优（手动实现循环）。\n",
    "\n",
    "Otto数据集是著名电商Otto提供的一个多类商品分类问题，类别数=9. 每个样本有93维数值型特征（整数，表示某种事件发生的次数，已经进行过脱敏处理）。 竞赛官网：https://www.kaggle.com/c/otto-group-product-classification-challenge/data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import all modules used in this notebook up front\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from matplotlib import pyplot as plt\n",
    "from scipy.sparse import csr_matrix\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.svm import LinearSVC\n",
    "\n",
    "# The competition metric is log loss, but LinearSVC does not output probabilities,\n",
    "# so accuracy_score is used as the model-selection metric in this notebook.\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.metrics import classification_report\n",
    "from sklearn.metrics import confusion_matrix"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 读取数据 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Read the data\n",
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "# Directory holding the feature files; override it with the OTTO_DATA_DIR environment variable.\n",
    "# A raw string keeps the backslashes of the Windows path literal (no escape-sequence surprises).\n",
    "DATA_DIR = Path(os.environ.get('OTTO_DATA_DIR', r'E:\\Py Pro\\data'))\n",
    "\n",
    "# Use raw features + tf-idf features.\n",
    "# A linear SVM still trains quickly on raw + tf-idf features, while the RBF kernel is already\n",
    "# far too slow on them, so for the RBF kernel only the tf-idf features are used.\n",
    "train_org = pd.read_csv(DATA_DIR / \"Otto_FE_train_org.csv\")\n",
    "train_tfidf = pd.read_csv(DATA_DIR / \"Otto_FE_train_tfidf.csv\")\n",
    "\n",
    "# The two frames are concatenated column-wise by row position, so verify that the rows line up\n",
    "# before dropping the duplicated id/target columns from the tf-idf frame.\n",
    "assert len(train_org) == len(train_tfidf), 'feature files have different row counts'\n",
    "assert (train_org['id'].values == train_tfidf['id'].values).all(), 'row order differs between feature files'\n",
    "train_tfidf = train_tfidf.drop([\"id\", \"target\"], axis=1)\n",
    "train = pd.concat([train_org, train_tfidf], axis=1)\n",
    "\n",
    "# Free the intermediate frames\n",
    "del train_org, train_tfidf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 61878 entries, 0 to 61877\n",
      "Columns: 188 entries, id to feat_93_tfidf\n",
      "dtypes: float64(186), int64(1), object(1)\n",
      "memory usage: 88.8+ MB\n"
     ]
    }
   ],
   "source": [
    "train.info()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 准备数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Separate labels from features. The labels stay as strings ('Class_1' ... 'Class_9'):\n",
    "# LinearSVC accepts string class labels directly, so no numeric encoding is needed.\n",
    "y_train = train['target']\n",
    "X_train = train.drop([\"id\", \"target\"], axis=1)\n",
    "\n",
    "# Keep the feature names for later use (visualization)\n",
    "feat_names = X_train.columns\n",
    "\n",
    "# Most sklearn estimators accept sparse input; a CSR matrix makes training much faster\n",
    "# when most feature values are zero.\n",
    "from scipy.sparse import csr_matrix\n",
    "X_train = csr_matrix(X_train.values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Anacanda\\lib\\site-packages\\sklearn\\model_selection\\_split.py:2179: FutureWarning: From version 0.21, test_size will always complement train_size unless both are specified.\n",
      "  FutureWarning)\n"
     ]
    }
   ],
   "source": [
    "# With 60k+ training samples, cross-validation is too slow (SVMs scale poorly to large data sets),\n",
    "# so model performance is estimated on a single hold-out split instead.\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "N_TRAIN_PART = 10000  # samples used for fitting; all remaining samples form the validation set\n",
    "# test_size=None explicitly means 'complement of train_size' (the split used so far) and\n",
    "# silences the FutureWarning older scikit-learn emits when only train_size is given.\n",
    "X_train_part, X_val, y_train_part, y_val = train_test_split(\n",
    "    X_train, y_train, train_size=N_TRAIN_PART, test_size=None, random_state=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10000, 186)\n"
     ]
    }
   ],
   "source": [
    "# Sizes of the fitting split and the validation split\n",
    "X_train_part.shape, X_val.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型训练"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 默认参数的 LinearSVC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LinearSVC: linear-kernel SVM solved with liblinear; it has predict but no predict_proba\n",
    "from sklearn.svm import LinearSVC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LinearSVC(C=1.0, class_weight=None, dual=True, fit_intercept=True,\n",
       "     intercept_scaling=1, loss='squared_hinge', max_iter=1000,\n",
       "     multi_class='ovr', penalty='l2', random_state=None, tol=0.0001,\n",
       "     verbose=0)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# LinearSVC cannot produce per-class probabilities (it has predict but no predict_proba), whereas the\n",
    "# Otto competition requires a probability for every class; this section only demonstrates SVM usage.\n",
    "# Reference: scikit-learn user guide (Chinese translation), section 1.4.1.2 'Scores and probabilities':\n",
    "# https://xacecask2.gitbooks.io/scikit-learn-user-guide-chinese-version/content/sec1.4.html\n",
    "\n",
    "# 1. Create the estimator with default hyper-parameters. random_state fixes liblinear's data\n",
    "#    shuffling in the dual coordinate descent so that results are reproducible.\n",
    "SVC1 = LinearSVC(random_state=0)\n",
    "\n",
    "# 2. Train the model\n",
    "SVC1.fit(X_train_part, y_train_part)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "accuracy is:  0.7596283588418983\n",
      "Classification report for classifier LinearSVC(C=1.0, class_weight=None, dual=True, fit_intercept=True,\n",
      "     intercept_scaling=1, loss='squared_hinge', max_iter=1000,\n",
      "     multi_class='ovr', penalty='l2', random_state=None, tol=0.0001,\n",
      "     verbose=0):\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "     Class_1       0.63      0.30      0.41      1615\n",
      "     Class_2       0.65      0.86      0.74     13479\n",
      "     Class_3       0.51      0.33      0.40      6706\n",
      "     Class_4       0.72      0.14      0.23      2267\n",
      "     Class_5       0.94      0.96      0.95      2312\n",
      "     Class_6       0.93      0.93      0.93     11883\n",
      "     Class_7       0.71      0.61      0.65      2399\n",
      "     Class_8       0.84      0.92      0.88      7077\n",
      "     Class_9       0.79      0.86      0.83      4140\n",
      "\n",
      "   micro avg       0.76      0.76      0.76     51878\n",
      "   macro avg       0.75      0.66      0.67     51878\n",
      "weighted avg       0.75      0.76      0.74     51878\n",
      "\n",
      "\n",
      "Confusion matrix:\n",
      "[[  487    64    14     8     8   140    48   382   464]\n",
      " [   15 11626  1435    61    74    43   102    68    55]\n",
      " [    5  4209  2223    26     9    16   155    37    26]\n",
      " [    0  1356   386   308    17   114    62    17     7]\n",
      " [    3    71     4     0  2224     3     0     4     3]\n",
      " [   71   118    24    12     4 11006   166   270   212]\n",
      " [   57   260   193     7    18   171  1462   200    31]\n",
      " [   72    84    35     2     6   207    51  6494   126]\n",
      " [   69    73     6     5    10   158    23   218  3578]]\n"
     ]
    }
   ],
   "source": [
    "# 3. Estimate model performance on the validation split\n",
    "y_predict = SVC1.predict(X_val)\n",
    "val_accuracy = accuracy_score(y_val, y_predict)\n",
    "print(\"accuracy is: \", val_accuracy)\n",
    "\n",
    "val_report = classification_report(y_val, y_predict)\n",
    "print(\"Classification report for classifier %s:\\n%s\\n\" % (SVC1, val_report))\n",
    "\n",
    "val_confusion = confusion_matrix(y_val, y_predict)\n",
    "print(\"Confusion matrix:\\n%s\" % val_confusion)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
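    "使用原始特征 + tfidf特征的线性SVM（缺省参数）在校验集上的分类性能：accuracy ≈ 0.7596（见上方输出）\n",
    "\n",
    "Class_1、Class_3和Class_4分类效果不好（召回率分别只有0.30、0.33、0.14）。\n",
    "是因为这几类样本数目少吗？Class_1（校验集中1615个）和Class_4（2267个）确实较少，但Class_5样本同样较少（2312个），分类效果却很好；而Class_3样本并不少（6706个）。从混淆矩阵看，Class_3和Class_4大多被误分为Class_2，可能是它们与Class_2难以线性区分。后面采用类别权重试试\n",
    "(用class_weight='balanced'效果更差了)"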
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 线性SVM正则参数调优"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "线性SVM LinearSVC需要调整的正则超参数包括C（正则系数的倒数：C越大，正则越弱；一般在log域（取log后的值）均匀设置候选参数）和正则函数penalty（L2/L1）\n",
    "\n",
    "采用交叉验证，网格搜索步骤与Logistic回归正则参数处理类似，在此略。\n",
    "\n",
    "这里我们用校验集（X_val、y_val）来估计模型性能"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_grid_point_Linear(C, X_train, y_train, X_val, y_val,\n",
    "                          penalty='l2', max_iter=1000, random_state=0):\n",
    "    \"\"\"Fit a LinearSVC for one hyper-parameter setting and return its validation accuracy.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    C : float\n",
    "        Inverse regularization strength (larger C = weaker regularization).\n",
    "    X_train, y_train :\n",
    "        Data the model is fitted on.\n",
    "    X_val, y_val :\n",
    "        Held-out data used to estimate performance.\n",
    "    penalty : {'l2', 'l1'}, default 'l2'\n",
    "        Regularization norm. liblinear supports penalty='l1' (with the squared hinge loss)\n",
    "        only in the primal, so dual=False is used for it; 'l2' keeps the default dual=True.\n",
    "    max_iter : int, default 1000\n",
    "        Maximum liblinear iterations; increase it if a ConvergenceWarning appears (large C).\n",
    "    random_state : int or None, default 0\n",
    "        Seed for liblinear's data shuffling so that repeated runs give identical results.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float\n",
    "        Mean accuracy on (X_val, y_val); the value is also printed.\n",
    "    \"\"\"\n",
    "    SVC2 = LinearSVC(C=C, penalty=penalty, dual=(penalty == 'l2'),\n",
    "                     max_iter=max_iter, random_state=random_state)\n",
    "    SVC2.fit(X_train, y_train)\n",
    "\n",
    "    # Accuracy on the validation split\n",
    "    accuracy = SVC2.score(X_val, y_val)\n",
    "    print(\"C= {} : accuracy= {} \".format(C, accuracy))\n",
    "    return accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C= 0.1 : accuracy= 0.750183121939936 \n",
      "C= 1.0 : accuracy= 0.7596283588418983 \n",
      "C= 10.0 : accuracy= 0.7618836501021627 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Anacanda\\lib\\site-packages\\sklearn\\svm\\base.py:922: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n",
      "  \"the number of iterations.\", ConvergenceWarning)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C= 100.0 : accuracy= 0.760919850418289 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No handles with labels found to put in legend.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "C= 1000.0 : accuracy= 0.6419098654535641 \n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYsAAAEKCAYAAADjDHn2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xu8VOV97/HPl7sGKbddowJCLFDxEi9bjLU2QmNCqIoao6ARgdnantScmpdJjkl7TlrT80ransaTNran1itq8YI3olZK8RJrNbLxEgW8IEbdiAEpYlC5/84fa+0wbPbeM2z2mjUz+/t+vea1Z9Y8M/PbAzPfvZ61nudRRGBmZtaZXnkXYGZm1c9hYWZmJTkszMysJIeFmZmV5LAwM7OSHBZmZlaSw8LMzEpyWJiZWUkOCzMzK6lP3gV0l+HDh8fo0aPzLsPMrKYsXbr0vYhoKNWubsJi9OjRNDc3512GmVlNkfRmOe3cDWVmZiU5LMzMrCSHhZmZlVQ3xyzMzHq6bdu20dLSwubNm/e4b8CAAYwYMYK+fft26bkdFmZmdaKlpYUDDjiA0aNHI+nX2yOC9evX09LSwpgxY7r03O6GMjOrE5s3b2bYsGG7BQWAJIYNG9buHke5HBZmZnWkbVCU2l4ud0OZddHOnbBpU3L51a86/xkBffsml379Or9e6v72rvfuDfv4XWDWqUzDQtIU4EdAb+C6iPhBm/uvBialN/cHfjMiBqf3jQKuA0YCAUyNiF9kWa/VrwjYurX0l/re/Pzoo7x/q91lFURdub63j3PYVb/MwkJSb+Aa4DSgBVgiaUFELG9tExFfL2r/NeDYoqeYC/zviFgkaSCwM6tarfrs3Akffti9X+7bt5f32r16wQEHJJeBA3f9HDVq99vl/hw4MPki3LZt12Xr1q5d7+rj2rv+wQflv87OjD990u4h8tWvwl/+ZbavWa8iot0up4jYp+fNcs9iIrAyIlYBSLodmAYs76D9DOC7adsJQJ+IWAQQEZsyrNO6wZYt3felvmlTEhTlGjBgzy/oIUNg5Mi9/2I/4IDk+bL4K7d//+RSi3bu7P6w6mjbk0/C1VfDt74Fgwbl/ZvXlgEDBrB+/fo9DnK3ng01YMCALj93lmFxCPB20e0W4MT2Gko6FBgDPJJuGge8L+medPu/A1dGxI42j7sUuBRg1KhR3Vq87RIBr78Ozc2wZAk8+yysW5d8sbd+uW/bVt5z9erV/pf0yJFd/6u9j4+8Za5Xr8qF3dNPw0knwR13wCWXZP969WTEiBG0tLSwbt26Pe5rHWfRVVl+zNr726yj/aDpwPyiMOgDnELSLfUWcAcwC7h+tyeLuBa4FqCxsXHf9rEMSILh7beTUGhu3nV5//3k/gED4NOfhvHj9+6LvfX6fvu5b9o6d+KJMGECXH+9w2Jv9e3bt8vjKErJMixaSA5OtxoBvNNB2+nAH7d57HNFXVj3AZ+hTVjYvnv33V3B0Pqz9Y+SPn3g6KPh/POhsTG5HHFE0q9slhUJCgW44gp46SU48si8KzLINiyWAGMljQFWkwTCBW0bSRoPDAGeavPYIZIaImIdMBnw/OP7aP36XXsKrcGwenVyX69eSRCcfnoSCiecAEcdlexJmFXaRRfBlVcmexdXX513NQYZhkVEbJd0GbCQ5NTZGyJimaSrgOaIWJA2nQHcHkWH6iNih6RvAIuVHKVZCvxzVrXWo40bk2MLxd1Jb7yx6/7x4+HUU3cFwzHHwCc+kVu5ZrtpaIBp0+CWW+AHP6jdEwPqifb1dKpq0djYGD118aMPP4Tnntt9r+HVV3fdP2bMrlBobITjjoPf+I386jUrx8MPwxe/CHfeCV/+ct7V1C9JSyOisVQ7n0dSY7ZsgRde2L0rafnyXefBH3JIEgozZ+46zjBsWL41m3XFaaclZ8ldd53Doho4LKrYtm2wbNnuwfDi
i7tOU21oSILhnHN2BcNBB+Vbs1l36d0bZs+G730P3nwTDj0074p6NodFldixA155ZfdjDM8/D62TRA4enITBFVfs6k4aOdKnoVp9aw2Lm26C734372p6Nh+zyEEErFy5+zGGZ5/dNWp54MDkuEJrKDQ2wmGHORisZ/r855M/pFatSvY2rHv5mEWViIC33tq9K2np0t0HuR17LMyZs+sg9Lhx/lCYtSoUYPp0WLw4CQ7Lh8Oim61Zs3swFA9y69t31yC31r2GCRM8yM2sM2edBUOHJmMuHBb5cVjsg/fe231KjCVL4J10jHrxILfWYDj6aJ8vbra3+vdPBun9wz8kn7nhw/OuqGdyWJRp48ak+6h4r+EXv9h1//jxMHnyrmMMxx4L+++fW7lmdaVQgB/9CG69FS6/PO9qeiYf4G5H6yC34q6ktoPcWvcWTjghCQYPcjPL1sSJ8PHH8POf+2SP7uQD3GXatm3PYCge5DZiRBIKM2cmwXD88R7kZpaHpib4wz+EZ55JZqa1yurxYbF27a7/eK2D3L70pV3dSZ/8ZL71mVli+nT4+teTA90Oi8rr8WFx8MFw//1JV9KIEd69NatWgwYl037Mmwc//GEyHskqp1feBeRNgjPP9Ghos1rQ1JSszHjXXXlX0vP0+LAws9px8snJmYfXexm0inNYmFnNkJLZDp58ElasyLuansVhYWY1ZebMZMnfG27Iu5KexWFhZjXlk59MZkaYOxe2bs27mp7DYWFmNadQSE57f+CBvCvpORwWZlZzpkxJTnv3ge7KcViYWc3p0wdmzUrW6W5pybuansFhYWY1ac6cZFqem27Ku5KewWFhZjXpsMNg0qTkrKjWudwsOw4LM6tZhQK88QY89ljeldQ/h4WZ1axzzoHBg+G66/KupP5lGhaSpkh6RdJKSVe2c//Vkp5PL69Ker/N/YMkrZb04yzrNLPatN9+cOGFcM89sGFD3tXUt8zCQlJv4Brgi8AEYIakCcVtIuLrEXFMRBwD/D1wT5un+R7weFY1mlntKxRgyxa47ba8K6lvWe5ZTARWRsSqiNgK3A5M66T9DGBe6w1JxwMHAv+WYY1mVuOOPRaOOy7piqqThT+rUpZhcQjwdtHtlnTbHiQdCowBHklv9wL+FvhmZy8g6VJJzZKa161b1y1Fm1ntKRTghRfg2WfzrqR+ZRkW7a0O0VHuTwfmR8SO9PZXgYci4u0O2idPFnFtRDRGRGNDQ8M+lGpmteyCC2DAAI/ozlKWYdECjCy6PQJ4p4O20ynqggJOAi6T9Avg/wAzJf0giyLNrPYNHgznnpsct/joo7yrqU9ZhsUSYKykMZL6kQTCgraNJI0HhgBPtW6LiAsjYlREjAa+AcyNiD3OpjIza1UowAcfwN13511JfcosLCJiO3AZsBBYAdwZEcskXSXpzKKmM4DbI3xoysy67rOfTUZ1uysqG6qX7+jGxsZobm7Ouwwzy9H3vw/f+Q68+iqMHZt3NbVB0tKIaCzVziO4zaxuXHwx9OrlVfSy4LAws7px8MEwdWoyE+327XlXU18cFmZWV5qa4N134aGH8q6kvjgszKyuTJ0KBx7oA93dzWFhZnWlb9/k2MWDD8KaNXlXUz8cFmZWdwoF2LEDbr4570rqh8PCzOrOuHFwyinJWVF1Mjogdw4LM6tLhQK89ho88UTeldQHh4WZ1aVzz4VBg7yKXndxWJhZXfrEJ2DGDJg/HzZuzLua2uewMLO6VSjAxx/DvHml21rnHBZmVrcaG+Hoo90V1R0cFmZWt6Rk72Lp0mQlPes6h4WZ1bULL4R+/Tyie185LMysrg0bBuecA7feCps3511N7XJYmFndKxRgwwa49968K6ldDgszq3uTJ8Po0e6K2hcOCzOre716wZw5sHgxrFqVdzW1yWFhZj3CrFnJ2VE33ph3JbXJYWFmPcLIkTBlShIWO3bkXU3tcViYWY9RKMDq1bBwYd6V1B6HhZn1GGecAQ0N
PtDdFQ4LM+sx+vWDmTNhwQL45S/zrqa2OCzMrEcpFGD7drjllrwrqS2ZhoWkKZJekbRS0pXt3H+1pOfTy6uS3k+3HyPpKUnLJP1c0vlZ1mlmPcfhh8NJJyVdUV5Fr3yZhYWk3sA1wBeBCcAMSROK20TE1yPimIg4Bvh74J70ro+AmRFxBDAF+L+SBmdVq5n1LE1N8PLL8J//mXcltSPLPYuJwMqIWBURW4HbgWmdtJ8BzAOIiFcj4rX0+jvAWqAhw1rNrAc57zwYONAHuvdGlmFxCPB20e2WdNseJB0KjAEeaee+iUA/4PV27rtUUrOk5nXr1nVL0WZW/wYOhPPPhzvugA8+yLua2pBlWKidbR31EE4H5kfEbkNlJB0E3ALMjoidezxZxLUR0RgRjQ0N3vEws/I1NcFHHyWBYaVlGRYtwMii2yOAdzpoO520C6qVpEHAg8CfRcTTmVRoZj3WiSfChAnuiipXlmGxBBgraYykfiSBsKBtI0njgSHAU0Xb+gH3AnMj4q4MazSzHqp1Fb2f/QxeeinvaqpfZmEREduBy4CFwArgzohYJukqSWcWNZ0B3B6x20ls5wG/B8wqOrX2mKxqNbOe6aKLoG9f712UQ1EnJxo3NjZGc3Nz3mWYWY358pfh0UeTOaP698+7msqTtDQiGku18whuM+vRCgVYvz6ZAsQ65rAwsx7ttNOS6cuvuy7vSqqbw8LMerTevWH2bFi0CN58M+9qqpfDwsx6vNmzk5833ZRrGVWtrLCQdLekP5DkcDGzujN6NHzuc3DDDV5FryPlfvn/I3AB8JqkH0j67QxrMjOruEIB3noLFi/Ou5LqVFZYRMS/R8SFwHHAL4BFkv5T0mxJfbMs0MysEs46C4YO9ZiLjpTdrSRpGDALaAKeA35EEh6LMqnMzKyC+vdPBundey+8917e1VSfco9Z3AM8AewPnBERZ0bEHRHxNWBglgWamVVKoQDbtsGtt+ZdSfUpd8/ixxExISK+HxFriu8oZ+SfmVktOOooOOEEr6LXnnLD4vDileokDZH01YxqMjPLTVNTMrHgM8/kXUl1KTcsLomI91tvRMQG4JJsSjIzy8/06bD//j7Q3Va5YdFL0q8XM0rX1+6XTUlmZvkZNCiZXHDePNi0Ke9qqke5YbEQuFPS70uaTLJQ0cPZlWVmlp+mpiQo7vJqOr9Wblj8D5L1sf8b8MfAYuBbWRVlZpank0+G8ePdFVWs3EF5OyPiHyPi3Ij4UkT8U9v1ss3M6kXrKnpPPgkvv5x3NdWh3HEWYyXNl7Rc0qrWS9bFmZnlZeZM6NPHexetyu2GupFkfqjtwCRgLnBLVkWZmeXtwAPh9NNh7txkoF5PV25Y7BcRi0mWYX0zIv4cmJxdWWZm+WtqgrVr4YEH8q4kf+WGxeZ0evLXJF0m6WzgNzOsy8wsd1/4Ahx8sFfRg/LD4nKSeaH+O3A88BXg4qyKMjOrBn36wKxZ8PDDsHp13tXkq2RYpAPwzouITRHREhGz0zOinq5AfWZmuZozB3bu9Cp6JcMiPUX2+OIR3GZmPcVhh8GkSclZUTt35l1NfsrthnoOuF/SRZLOab1kWZiZWbUoFOCNN+Cxx/KuJD/lhsVQYD3JGVBnpJfTSz1I0hRJr0haKenKdu6/WtLz6eVVSe8X3XexpNfSi4+PmFluzjkHBg/u2Qe6+5TTKCJm7+0Tp8c6rgFOA1qAJZIWRMTyouf9elH7rwHHpteHAt8FGoEAlqaP3bC3dZiZ7av99oMLL0zCYsMGGDIk74oqr9wR3DdKuqHtpcTDJgIrI2JVRGwFbgemddJ+BskEhQBfABZFxH+lAbEImFJOrWZmWSgUYMsWuO22vCvJR7ndUA8AD6aXxcAgoNTkvYcAbxfdbkm37UHSocAYkskKy36spEslNUtqXrduXRm/hplZ1xx7LBx3XLJ30RNX0St3IsG7iy63AecBR5Z4WHtnT3X0Fk8H5hdNTljWYyPi2oho
jIjGhoaGEuWYme2bQgFeeAGefTbvSiqv3D2LtsYCo0q0aQFGFt0eAbzTQdvp7OqC2tvHmplVxAUXwIABPXNywXKPWfxK0getF+AnJGtcdGYJMFbSGEn9SAJhQTvPPR4YAjxVtHkh8Pl0re8hwOfTbWZmuRk8GM49Nzlu8dFHeVdTWeV2Qx0QEYOKLuMi4u4Sj9kOXEbyJb8CuDMilkm6StKZRU1nALdH7OoFjIj/Ar5HEjhLgKvSbWZmuSoU4IMP4O5OvwHrj6KMIzXpxIGPRMTG9PZg4NSIuC/j+srW2NgYzc3NeZdhZnUuAsaOhREj6mOQnqSlEdFYql25xyy+2xoUABHxPsk4CDOzHqV1Fb3HH4fXXsu7msopNyzaa1fWgD4zs3pz8cXQqxfcUGq0WR0pNyyaJf1Q0mGSPiXpamBploWZmVWrgw+GqVOTmWi3b8+7msooNyy+BmwF7gDuBD4G/jiroszMql1TE7z7Ljz0UN6VVEa5c0N9COwxEaCZWU81dWqyTvf118OZZ5ZuX+vKHWexKD0DqvX2EEke92BmPVbfvsmxiwcfhDVr8q4me+V2Qw1Pz4ACIJ3cz2twm1mPVijAjh1w8815V5K9csNip6RfT+8haTQdz/NkZtYjjBsHp5ySnBVV75MLlhsWfwr8h6RbJN0CPA58O7uyzMxqQ6GQjLd44om8K8lWudN9PEyyENErJGdEXUFyRpSZWY927rkwaFD9r6JX7gHuJpJ1LK5IL7cAf55dWWZmteETn4AZM2D+fNi4sXT7WlVuN9SfACcAb0bEJJLlT73akJkZyZiLjz+GefNKt61V5YbF5ojYDCCpf0S8DIzPriwzs9px/PFw9NH13RVVbli0pOMs7gMWSbofL0ZkZgbsmlxw6dJkJb16VO4B7rMj4v2I+HPgfwLXA2dlWZiZWS35ylegf//6XUVvr5dVjYjHI2JBRGzNoiAzs1o0dCicfTbceits3px3Nd2vq2twm5lZG4UCbNgA996bdyXdz2FhZtZNJk+G0aPrsyvKYWFm1k169YI5c2DxYli1Ku9qupfDwsysG82alZwddeONeVfSvRwWZmbdaORImDIlCYsdO/Kupvs4LMzMulmhAKtXw8I6WvXHYWFm1s3OOAMaGurrQHemYSFpiqRXJK2U1O6yrJLOk7Rc0jJJ/1K0/a/TbSsk/Z0kZVmrmVl36dcPZs6EBQtg7dq8q+kemYWFpN7ANcAXgQnADEkT2rQZS7IuxskRcQRwebr9d4CTgaOBI0kmMfxsVrWamXW3QgG2b4e5c/OupHtkuWcxEVgZEavS0d63A9PatLkEuCZdppWIaM3gAAYA/YD+QF/glxnWambWrQ4/HE46KemKqodV9LIMi0OAt4tut6Tbio0Dxkl6UtLTkqYARMRTwKPAmvSyMCJWZFirmVm3a2qCl1+Gp57Ku5J9l2VYtHeMoW2+9gHGAqcCM4DrJA2W9FvA4cAIkoCZLOn39ngB6VJJzZKa163z8hpmVl3OOw8GDqyPqcuzDIsWYGTR7RHsOa15C3B/RGyLiDdIlm0dC5wNPB0RmyJiE/CvwGfavkBEXBsRjRHR2NDQkMkvYWbWVQMHwvnnw513wq9+lXc1+ybLsFgCjJU0RlI/YDqwoE2b+4BJAJKGk3RLrQLeAj4rqY+kviQHt90NZWY1p6kJPvwQ7rgj70r2TWZhERHbgcuAhSRf9HdGxDJJV0k6M222EFgvaTnJMYpvRsR6YD7wOvAi8ALwQkT8JKtazcyycuKJMGFC7XdFKerhMD3Q2NgYzc3NeZdhZraHH/4QrrgCXnwRjjwy72p2J2lpRDSWaucR3GZmGbvoIujbt7ZHdDsszMwy1tAA06bBLbfAli15V9M1DgszswooFGD9+mQKkFrksDAzq4DTTkumL6/VA90OCzOzCujdG2bPhkWL4M03865m7zkszMwqZPbs5OdNN+VaRpc4LMzMKmT0aPjc5+CGG2pvFT2HhZlZ
BRUK8NZbsHhx3pXsHYeFmVkFnXUWDB1ae2MuHBZmZhXUv38ySO/ee+G99/KupnwOCzOzCisUYNs2uPXWvCspn8PCzKzCjjoKJk6srVX0HBZmZjkoFOCll+CZZ/KupDwOCzOzHEyfDvvvXzsHuh0WZmY5GDQoWXZ13jzYtCnvakpzWJiZ5aRQSILirrvyrqQ0h4WZWU5OPhnGj6+NriiHhZlZTqRk7+LJJ+Hll/OupnMOCzOzHM2cCX36VP/ehcPCzCxHBx4Ip58Oc+cmA/WqlcPCzCxnTU2wdi088EDelXTMYWFmlrMvfAEOPri6V9FzWJiZ5axPH5g1Cx5+GFavzrua9jkszMyqwJw5sHNn9a6il2lYSJoi6RVJKyVd2UGb8yQtl7RM0r8UbR8l6d8krUjvH51lrWZmeTrsMJg0KTkraufOvKvZU2ZhIak3cA3wRWACMEPShDZtxgLfBk6OiCOAy4vungv8TUQcDkwE1mZVq5lZNSgU4I034LHH8q5kT1nuWUwEVkbEqojYCtwOTGvT5hLgmojYABARawHSUOkTEYvS7Zsi4qMMazUzy90558DgwdU55iLLsDgEeLvodku6rdg4YJykJyU9LWlK0fb3Jd0j6TlJf5PuqZiZ1a399oMLL4S774YNG/KuZndZhoXa2dZ2mY8+wFjgVGAGcJ2kwen2U4BvACcAnwJm7fEC0qWSmiU1r1u3rvsqNzPLSaEAW7bAbbflXcnusgyLFmBk0e0RwDvttLk/IrZFxBvAKyTh0QI8l3ZhbQfuA45r+wIRcW1ENEZEY0NDQya/hJlZJR17LBx3XDLmoppW0csyLJYAYyWNkdQPmA4saNPmPmASgKThJN1Pq9LHDpHUmgCTgeUZ1mpmVjUKBXjhBXj22bwr2SWzsEj3CC4DFgIrgDsjYpmkqySdmTZbCKyXtBx4FPhmRKyPiB0kXVCLJb1I0qX1z1nVamZWTS64AAYMqK4D3Ypq2s/ZB42NjdHc3Jx3GWZm3eKii+AnP4F33kmWX82KpKUR0ViqnUdwm5lVoUIBNm5MzoyqBg4LM7Mq9NnPJqO6q6UrymFhZlaFWlfRe/xxeO21vKtxWJiZVa2LL4ZeveCGG/KuxGFhZla1Dj4Y/uAPkplot2/PtxaHhZlZFSsU4N134aGH8q3DYWFmVsWmTk3W6c77QLfDwsysivXtm6yi9+CDsGZNfnU4LMzMqtycObBjB9x8c341OCzMzKrcuHFwyinJWVF5TbrhsDAzqwFNTcl4iyeeyOf1HRZmZjXg3HNh0KBk6vI8OCzMzGrA/vvDjBkwf34yZ1SlOSzMzGpEUxN8/DHMm1f513ZYmJnViOOPh6OPzqcrymFhZlYjWicXXLo0WUmvkhwWZmY15Ctfgf79Kz+i22FhZlZDhg6Fs8+GW2+FzZsr97oOCzOzGlMowIYNcO+9lXtNh4WZWY2ZPBlGj65sV5TDwsysxvTqlcwXtXgxrFpVodeszMuYmVl3mjUrOTvqxhsr83oOCzOzGjRyJEyZkoTFjh3Zv57DwsysRhUKsHo1LFyY/WtlGhaSpkh6RdJKSVd20OY8ScslLZP0L23uGyRptaQfZ1mnmVktOuMMaGiozIHuPlk9saTewDXAaUALsETSgohYXtRmLPBt4OSI2CDpN9s8zfeAx7Oq0cyslvXrB5dfDh99lKxzIWX3WpmFBTARWBkRqwAk3Q5MA5YXtbkEuCYiNgBExNrWOyQdDxwIPAw0ZlinmVnN+s53KvM6WXZDHQK8XXS7Jd1WbBwwTtKTkp6WNAVAUi/gb4FvZlifmZmVKcs9i/Z2iNouCNgHGAucCowAnpB0JPAV4KGIeFud7FdJuhS4FGDUqFHdULKZmbUny7BoAUYW3R4BvNNOm6cjYhvwhqRXSMLjJOAUSV8FBgL9JG2KiN0OkkfEtcC1AI2NjTmtTGtmVv+y7IZaAoyVNEZSP2A6sKBNm/uASQCShpN0S62KiAsjYlREjAa+
AcxtGxRmZlY5mYVFRGwHLgMWAiuAOyNimaSrJJ2ZNlsIrJe0HHgU+GZErM+qJjMz6xpF1EfvTWNjYzQ3N+ddhplZTZG0NCJKnnHqEdxmZlaSw8LMzEqqm24oSeuAN/fhKYYD73VTOd3Jde0d17V3XNfeqce6Do2IhlKN6iYs9pWk5nL67SrNde0d17V3XNfe6cl1uRvKzMxKcliYmVlJDotdrs27gA64rr3juvaO69o7PbYuH7MwM7OSvGdhZmYl9diwkPTldHW+nZI6PIugnNX+urmuoZIWSXot/Tmkg3Y7JD2fXtrOudWd9XT6+0vqL+mO9P6fSRqdVS17UdMsSeuK3p+mrGtKX/cGSWslvdTB/ZL0d2ndP5d0XJXUdaqkjUXv1/+qUF0jJT0qaUX6WfyTdtpU/D0rs66Kv2eSBkh6RtILaV1/0U6b7D6PEdEjL8DhwHjgMaCxgza9gdeBTwH9gBeACRnX9dfAlen1K4G/6qDdpgq8RyV/f+CrwP9Lr08H7qiCmmYBP87h/9TvAccBL3Vw/1TgX0mm7/8M8LMqqetU4IEc3q+DgOPS6wcAr7bzb1nx96zMuir+nqXvwcD0el/gZ8Bn2rTJ7PPYY/csImJFRLxSotmvV/uLiK1A62p/WZoG3Jxevxk4K+PX60w5v39xvfOB31dni5BUpqZcRMRPgf/qpMk0khmUIyKeBgZLOqgK6spFRKyJiGfT678imXC07QJpFX/Pyqyr4tL3YFN6s296aXvQObPPY48NizKVs9pfdzswItZA8p8WaLsueasBkprTFQazCpRyfv9ft4lkpuGNwLCM6im3JoAvpd0W8yWNbOf+POTx/6lcJ6XdG/8q6YhKv3jaXXIsyV/LxXJ9zzqpC3J4zyT1lvQ8sBZYFBEdvl/d/XnMcvGj3En6d+CT7dz1pxFxfzlP0c62fT59rLO69uJpRkXEO5I+BTwi6cWIeH1fa2ujnN8/k/eoE+W83k+AeRGxRdIfkfylNTnDmspV6feqXM+STPmwSdJUknVmxlbqxSUNBO4GLo+ID9re3c5DKvKelagrl/csInYAx0gaDNwr6ciIKD4Wldn7VddhERGf28enKGe1v73WWV2SfinpoIhYk+5ur+3gOd5Jf6449O71AAADo0lEQVSS9BjJXz/dHRblrnY4EmiR1Af4DbLt8ihZU+y+Jso/A3+VYT17I5P/T/uq+IswIh6S9A+ShkdE5nMgSepL8oV8W0Tc006TXN6zUnXl+Z6lr/l++rmfAhSHRWafR3dDda6c1f662wLg4vT6xcAee0CShkjqn14fDpwMLM+glnJ+/+J6zwUeifToWkZK1tSmT/tMkj7narAAmJme4fMZYGNrl2OeJH2ytV9b0kSS74XMFyFLX/N6YEVE/LCDZhV/z8qpK4/3TFJDukeBpP2AzwEvt2mW3eexkkfzq+kCnE2SwluAXwIL0+0HAw8VtZtKcjbE6yTdV1nXNQxYDLyW/hyabm8Erkuv/w7wIsmZQC8ChQzr2eP3B64CzkyvDwDuAlYCzwCfqsB7VKqm7wPL0vfnUeC3K/R/ah6wBtiW/t8qAH8E/FF6v4Br0rpfpIOz8HKo67Ki9+tp4HcqVNfvknSR/Bx4Pr1Mzfs9K7Ouir9nwNHAc2ldLwH/K91ekc+jR3CbmVlJ7oYyM7OSHBZmZlaSw8LMzEpyWJiZWUkOCzMzK8lhYbYXJG0q3arTx89PR90jaaCkf5L0ejqL6E8lnSipX3q9rgfNWm1xWJhVSDp/UO+IWJVuuo5kdO3YiDiCZLbc4ZFMkLgYOD+XQs3a4bAw64J0RPHfSHpJ0ouSzk+390qnflgm6QFJD0k6N33YhaQj8iUdBpwI/FlE7IRk6paIeDBte1/a3qwqeDfXrGvOAY4BPg0MB5ZI+inJ1CujgaNIZgxeAdyQPuZkktHUAEcAz0cyMVx7XgJOyKRysy7wnoVZ1/wuycy2OyLil8DjJF/uvwvcFRE7I+Jd
kulGWh0ErCvnydMQ2SrpgG6u26xLHBZmXdPRgjKdLTTzMcncPZDMK/RpSZ19BvsDm7tQm1m3c1iYdc1PgfPTxWgaSJYufQb4D5KFl3pJOpBk+c1WK4DfAohk7ZFm4C+KZi8dK2laen0YsC4itlXqFzLrjMPCrGvuJZn98wXgEeBbabfT3SQzu74E/BPJCmsb08c8yO7h0USyCNZKSS+SrL3RulbDJOChbH8Fs/J51lmzbiZpYCQrqA0j2ds4OSLeTdcgeDS93dGB7dbnuAf4dpReJ96sInw2lFn3eyBdpKYf8L10j4OI+FjSd0nWSX6rowenizrd56CwauI9CzMzK8nHLMzMrCSHhZmZleSwMDOzkhwWZmZWksPCzMxKcliYmVlJ/x/8/3+jUJ+X4wAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Hyper-parameter to tune. SVM training is slow, so only C is searched here;\n",
    "# penalty stays at LinearSVC's default ('l2').\n",
    "# np.logspace(a, b, N) returns N values evenly spaced on a log10 scale from 10**a to 10**b,\n",
    "# i.e. C in {0.1, 1, 10, 100, 1000}.\n",
    "C_s = np.logspace(-1, 3, 5)\n",
    "\n",
    "accuracy_s = []\n",
    "for oneC in C_s:\n",
    "    accuracy_s.append(fit_grid_point_Linear(oneC, X_train_part, y_train_part, X_val, y_val))\n",
    "\n",
    "# Validation accuracy as a function of log10(C)\n",
    "x_axis = np.log10(C_s)\n",
    "fig, ax = plt.subplots()\n",
    "# The line needs a label; otherwise legend() warns 'No handles with labels found to put in legend.'\n",
    "ax.plot(x_axis, np.array(accuracy_s), 'b-', label='penalty = l2')\n",
    "ax.legend()\n",
    "ax.set_xlabel('log(C)')\n",
    "ax.set_ylabel('accuracy')\n",
    "ax.set_title('LinearSVC: validation accuracy vs. C')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10.0\n"
     ]
    }
   ],
   "source": [
    "# Best hyper-parameter: the candidate C with the highest validation accuracy (first one on ties)\n",
    "index = int(np.argmax(accuracy_s))\n",
    "Best_C = C_s[index]\n",
    "\n",
    "print(Best_C)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## 找到最佳参数后，用全体训练数据训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Retrain LinearSVC on the full training set with the C selected in the previous cell\n",
    "# (highest validation accuracy). Note: LinearSVC still gives no probability output (no predict_proba).\n",
    "import pickle\n",
    "\n",
    "SVC3 = LinearSVC(C=Best_C, random_state=0)\n",
    "SVC3.fit(X_train, y_train)\n",
    "\n",
    "# Save the model for later testing. cPickle does not exist in Python 3; pickle replaces it.\n",
    "# Only load this file from a trusted source: unpickling can execute arbitrary code.\n",
    "with open(\"Otto_LinearSVC.pkl\", 'wb') as f:\n",
    "    pickle.dump(SVC3, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
